# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified by Xingyi Zhou
import copy
import glob
import itertools
import json
import os
from collections import OrderedDict

import pycocotools.mask as mask_util
from PIL import Image

import detectron2.utils.comm as comm
from detectron2.evaluation.coco_evaluation import (
    COCOEvaluator,
    _evaluate_predictions_on_coco,
    instances_to_coco_json,
)
from fvcore.common.file_io import PathManager

from .oideval import OIDEvaluator, _evaluate_predictions_on_oid


def get_unified_evaluator(evaluator_type, dataset_name, cfg, distributed, output_dir):
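    """
    Build an evaluator for `dataset_name` that maps predictions made in the
    unified label space back to the dataset's own label space before scoring.
    """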
    unified_label_file = cfg.MULTI_DATASET.UNIFIED_LABEL_FILE
    if evaluator_type == "coco":
        evaluator = UnifiedCOCOEvaluator(
            unified_label_file, dataset_name, cfg, distributed, output_dir
        )
    elif evaluator_type == "oid":
        evaluator = UnifiedOIDEvaluator(
            unified_label_file, dataset_name, cfg, distributed, output_dir
        )
    elif evaluator_type == "cityscapes_instance":
        evaluator = UnifiedCityscapesEvaluator(
            unified_label_file, dataset_name, cfg, distributed, output_dir
        )
    else:
        raise ValueError("Unknown evaluator type: {}".format(evaluator_type))
    return evaluator


def map_back_unified_id(results, map_back, reverse_id_mapping=None):
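    """
    Map unified category ids back to dataset-specific ids.

    Predictions whose unified id has no entry in `map_back` are dropped.
    Example (hypothetical ids): with map_back = {7: 2} and
    reverse_id_mapping = {2: 18}, {"category_id": 7, ...} becomes
    {"category_id": 18, ...}.
    """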
    ret = []
    for result in results:
        if result["category_id"] in map_back:
            result["category_id"] = map_back[result["category_id"]]
            if reverse_id_mapping is not None:
                result["category_id"] = reverse_id_mapping[result["category_id"]]
            ret.append(result)
    return ret


def map_back_unified_id_novel_classes(results, map_back, reverse_id_mapping=None):
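    """
    Like `map_back_unified_id`, but each unified id can match several original
    ids; the prediction is deep-copied once per matched id, so one unified
    detection may be scored under multiple novel classes.
    """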
    ret = []
    for result in results:
        if result["category_id"] in map_back:
            original_id_list = map_back[result["category_id"]]
            for original_id in original_id_list:
                result_copy = copy.deepcopy(result)
                result_copy["category_id"] = original_id
                if reverse_id_mapping is not None:
                    result_copy["category_id"] = reverse_id_mapping[result_copy["category_id"]]
                ret.append(result_copy)
    return ret


class UnifiedCOCOEvaluator(COCOEvaluator):
    def _eval_predictions(self, tasks, predictions, img_ids=None):
        """
        Evaluate predictions. Fill self._results with the metrics of the tasks.
        """
        self._logger.info("Preparing results for COCO format ...")
        coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
        tasks = self._tasks or self._tasks_from_predictions(coco_results)

        # unmap the category ids for COCO
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
            all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
            num_classes = len(all_contiguous_ids)
            assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1

            reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
            for result in coco_results:
                category_id = result["category_id"]
                assert category_id < num_classes, (
                    f"A prediction has class={category_id}, "
                    f"but the dataset only has {num_classes} classes and "
                    f"predicted class id should be in [0, {num_classes - 1}]."
                )
                result["category_id"] = reverse_id_mapping[category_id]

        if self._output_dir:
            file_path = os.path.join(self._output_dir, "coco_instances_results.json")
            self._logger.info("Saving results to {}".format(file_path))
            with PathManager.open(file_path, "w") as f:
                f.write(json.dumps(coco_results))
                f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info(
            "Evaluating predictions with {} COCO API...".format(
                "unofficial" if self._use_fast_impl else "official"
            )
        )
        for task in sorted(tasks):
            assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
            coco_eval = (
                _evaluate_predictions_on_coco(
                    self._coco_api,
                    coco_results,
                    task,
                    kpt_oks_sigmas=self._kpt_oks_sigmas,
                    use_fast_impl=self._use_fast_impl,
                    img_ids=img_ids,
                    max_dets_per_image=self._max_dets_per_image,
                )
                if len(coco_results) > 0
                else None  # cocoapi does not handle empty results very well
            )

            res = self._derive_coco_results(
                coco_eval, task, class_names=self._metadata.get("thing_classes")
            )
            self._results[task] = res


class UnifiedCityscapesEvaluator(COCOEvaluator):
    def __init__(self, unified_label_file, dataset_name, cfg, distributed, output_dir=None):
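        """
        Evaluator for a model trained on the unified label space, scored with
        the cityscapes instance-level evaluation script after mapping
        predictions back to cityscapes classes.
        """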
        super().__init__(dataset_name, cfg, distributed, output_dir=output_dir)
        meta_dataset_name = dataset_name[: dataset_name.find("_")]
        self._logger.info("meta_dataset_name: {}".format(meta_dataset_name))

        self.unified_novel_classes_eval = cfg.MULTI_DATASET.UNIFIED_NOVEL_CLASSES_EVAL
        if self.unified_novel_classes_eval:
            match_novel_classes_file = cfg.MULTI_DATASET.MATCH_NOVEL_CLASSES_FILE
            self._logger.info("Loading map back from {}".format(match_novel_classes_file))
            with open(match_novel_classes_file, "r") as f:
                novel_classes_map = json.load(f)[meta_dataset_name]
            # novel_classes_map[c] lists the unified label ids matched to the
            # dataset's contiguous class c; invert it so each unified id maps
            # back to its dataset class.
            self.map_back = {}
            for c, match in enumerate(novel_classes_map):
                for m in match:
                    self.map_back[m] = c
        else:
            with open(unified_label_file, "r") as f:
                label_map = json.load(f)["label_map"][meta_dataset_name]
            # label_map[i] is the unified label id of the dataset's i-th
            # contiguous class; invert it to map unified ids back.
            self.map_back = {int(v): i for i, v in enumerate(label_map)}

        self._logger.info("saving outputs to {}".format(self._output_dir))
        self._temp_dir = self._output_dir + "/cityscapes_style_eval_tmp/"
        self._logger.info(
            "Writing cityscapes results to temporary directory {} ...".format(self._temp_dir)
        )
        PathManager.mkdirs(self._temp_dir)

    def process(self, inputs, outputs):
        """
        Args:
            inputs: the inputs to a COCO model (e.g., GeneralizedRCNN).
                It is a list of dict. Each dict corresponds to an image and
                contains keys like "height", "width", "file_name", "image_id".
            outputs: the outputs of a COCO model. It is a list of dicts with key
                "instances" that contains :class:`Instances`.
        """
        for input, output in zip(inputs, outputs):
            prediction = {"image_id": input["image_id"], "file_name": input["file_name"]}

            instances = output["instances"].to(self._cpu_device)
            prediction["instances"] = instances_to_coco_json(instances, input["image_id"])
            for x in prediction["instances"]:
                x["file_name"] = input["file_name"]
            self._predictions.append(prediction)

    def _eval_predictions(self, tasks, predictions):
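        """
        Map unified-label predictions back to cityscapes classes, dump them in
        the cityscapes result format, and run the official evaluation script.
        """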
        self._logger.info("Preparing results for COCO format ...")
        _unified_results = list(itertools.chain(*[x["instances"] for x in predictions]))
        all_file_names = [x["file_name"] for x in predictions]
        file_path = os.path.join(self._output_dir, "unified_instances_results.json")
        self._logger.info("Saving results to {}".format(file_path))
        with PathManager.open(file_path, "w") as f:
            f.write(json.dumps(_unified_results))
            f.flush()

        mapped = False
        thing_classes = None
        if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
            self._logger.info(
                "Evaluating COCO-style cityscapes! Using builtin metadata to map back IDs."
            )
            reverse_id_mapping = {
                v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
            }
            mapped = True
            thing_classes = {
                k: self._metadata.thing_classes[v]
                for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items()
            }
        else:
            self._logger.info("Evaluating cityscapes! " + "Using eval script to map back IDs.")
            reverse_id_mapping = None
            thing_classes = self._metadata.thing_classes

        if self.unified_novel_classes_eval:
            coco_results = map_back_unified_id_novel_classes(
                _unified_results, self.map_back, reverse_id_mapping=reverse_id_mapping
            )
        else:
            coco_results = map_back_unified_id(
                _unified_results, self.map_back, reverse_id_mapping=reverse_id_mapping
            )

        self.write_as_cityscapes(
            coco_results,
            all_file_names,
            temp_dir=self._temp_dir,
            mapped=mapped,
            thing_classes=thing_classes,
        )

        os.environ["CITYSCAPES_DATASET"] = os.path.abspath(
            os.path.join(self._metadata.gt_dir, "..", "..")
        )
        # Load the Cityscapes eval script *after* setting the required env var,
        # since the script reads CITYSCAPES_DATASET into global variables at load time.
        import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as cityscapes_eval

        self._logger.info("Evaluating results under {} ...".format(self._temp_dir))
        # set some global states in cityscapes evaluation API, before evaluating
        cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir)
        cityscapes_eval.args.predictionWalk = None
        cityscapes_eval.args.JSONOutput = False
        cityscapes_eval.args.colorized = False
        cityscapes_eval.args.gtInstancesFile = os.path.join(self._temp_dir, "gtInstances.json")

        # These lines are adopted from
        # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
        groundTruthImgList = glob.glob(cityscapes_eval.args.groundTruthSearch)
        assert len(
            groundTruthImgList
        ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format(
            cityscapes_eval.args.groundTruthSearch
        )
        predictionImgList = []
        for gt in groundTruthImgList:
            predictionImgList.append(cityscapes_eval.getPrediction(gt, cityscapes_eval.args))
        results = cityscapes_eval.evaluateImgLists(
            predictionImgList, groundTruthImgList, cityscapes_eval.args
        )["averages"]

        ret = OrderedDict()
        ret["segm"] = {"AP": results["allAp"] * 100, "AP50": results["allAp50%"] * 100}
        # COCOEvaluator.evaluate() reports the contents of self._results, so
        # store the metrics there as well; the return value alone is discarded.
        self._results.update(ret)
        return ret

    @staticmethod
    def write_as_cityscapes(
        coco_results,
        all_file_names,
        temp_dir,
        mapped=False,
        thing_classes=None,
        ext="_pred.txt",
        subfolder="",
    ):
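        """
        Write COCO-style instance predictions in the format expected by the
        cityscapes instance-level evaluation script: one ``<basename>_pred.txt``
        per image, each line ``<mask_png> <label_id> <confidence>``, with the
        binary instance masks saved as PNG files alongside the txt files (or
        under ``subfolder``).
        """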
        from cityscapesscripts.helpers.labels import name2label

        results_per_image = {x: [] for x in all_file_names}
        for x in coco_results:
            results_per_image[x["file_name"]].append(x)
        if subfolder != "":
            PathManager.mkdirs(os.path.join(temp_dir, subfolder))
        N = len(results_per_image)
        step = max(N // 10, 1)  # avoid ZeroDivisionError when N < 10
        for i, (file_name, coco_list) in enumerate(results_per_image.items()):
            if i % step == 0:
                print("{}%".format(i * 100 // max(N, 1)), end=",", flush=True)
            basename = os.path.splitext(os.path.basename(file_name))[0]
            pred_txt = os.path.join(temp_dir, basename + ext)

            num_instances = len(coco_list)
            with open(pred_txt, "w") as fout:
                # use a separate loop variable for instances so the image
                # index `i` driving the progress printout is not shadowed
                for k in range(num_instances):
                    if not mapped:
                        # category_id is already a dataset-contiguous id;
                        # resolve the class name, then the cityscapes label id
                        pred_class = coco_list[k]["category_id"]
                        classes = thing_classes[pred_class]
                        class_id = name2label[classes].id
                    else:
                        class_id = coco_list[k]["category_id"]
                        classes = thing_classes[class_id]
                    score = coco_list[k]["score"]
                    mask = mask_util.decode(coco_list[k]["segmentation"]).astype("uint8")
                    png_basename = basename + "_{}_{}.png".format(k, classes.replace(" ", "_"))
                    if subfolder != "":
                        png_filename = os.path.join(temp_dir, subfolder, png_basename)
                        mask_entry = subfolder + "/" + os.path.basename(png_filename)
                    else:
                        png_filename = os.path.join(temp_dir, png_basename)
                        mask_entry = os.path.basename(png_filename)
                    Image.fromarray(mask * 255).save(png_filename)
                    fout.write("{} {} {}\n".format(mask_entry, class_id, score))


class UnifiedOIDEvaluator(OIDEvaluator):
    def __init__(self, unified_label_file, dataset_name, cfg, distributed, output_dir=None):
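        """
        OID evaluator for a model trained on the unified label space; builds
        `map_back` from the unified label file so predictions can be mapped
        back to OID ids at evaluation time.
        """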
        super().__init__(dataset_name, cfg, distributed, output_dir=output_dir)
        meta_dataset_name = dataset_name[: dataset_name.find("_")]
        self._logger.info("meta_dataset_name: {}".format(meta_dataset_name))
        with open(unified_label_file, "r") as f:
            label_map = json.load(f)["label_map"][meta_dataset_name]
        # label_map[i] is the unified label id of the dataset's i-th contiguous
        # class; invert it to map unified ids back to dataset ids.
        self.map_back = {int(v): i for i, v in enumerate(label_map)}
        self._logger.info("Saving outputs to {}".format(self._output_dir))

    def evaluate(self):
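        """
        Gather predictions across workers, map unified ids back to OID ids,
        save both result files, and run the OID evaluation.
        """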
        if self._distributed:
            comm.synchronize()
            self._predictions = comm.gather(self._predictions, dst=0)
            self._predictions = list(itertools.chain(*self._predictions))

            if not comm.is_main_process():
                return

        if len(self._predictions) == 0:
            self._logger.warning("[UnifiedOIDEvaluator] Did not receive valid predictions.")
            return {}

        self._logger.info("Preparing results in the OID format ...")
        _unified_results = list(itertools.chain(*[x["instances"] for x in self._predictions]))

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)

        file_path = os.path.join(self._output_dir, "unified_instances_results.json")
        self._logger.info("Saving results to {}".format(file_path))
        with PathManager.open(file_path, "w") as f:
            f.write(json.dumps(_unified_results))
            f.flush()

        self._oid_results = map_back_unified_id(_unified_results, self.map_back)

        # unmap the category ids for OID (from 0-indexed to 1-indexed)
        for result in self._oid_results:
            result["category_id"] += 1

        PathManager.mkdirs(self._output_dir)
        file_path = os.path.join(self._output_dir, "oid_instances_results.json")
        self._logger.info("Saving results to {}".format(file_path))
        with PathManager.open(file_path, "w") as f:
            f.write(json.dumps(self._oid_results))
            f.flush()

        if not self._do_evaluation:
            self._logger.info("Annotations are not available for evaluation.")
            return

        self._logger.info("Evaluating predictions ...")
        self._results = OrderedDict()
        res = _evaluate_predictions_on_oid(self._oid_api, file_path, eval_seg=self._mask_on)
        self._results["bbox"] = res

        return copy.deepcopy(self._results)
